{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "02112.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/02112.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aMbueFYmoQIa",
        "colab_type": "text"
      },
      "source": [
        "## 10.8 文本情感分类：使用卷积神经网络 ( textCNN )"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ttY6o92toyi5",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import os \n",
        "import torch \n",
        "from torch import nn \n",
        "import torchtext.vocab as Vocab \n",
        "import torch.utils.data as Data \n",
        "import torch.nn.functional as F \n",
        "import d2l \n",
        "import tarfile\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4bMcEOuj0jx7",
        "colab_type": "text"
      },
      "source": [
        "### 10.8.1 一维卷积层"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "F37i8Mll0qG1",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def corr1d(X, K):\n",
        "    w = K.shape[0]\n",
        "    Y = torch.zeros((X.shape[0] - w + 1))\n",
        "    for i in range(Y.shape[0]):\n",
        "        Y[i] = (X[i: i + w] * K).sum()\n",
        "    return Y "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_1eFupDp09Bc",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "c7fc260c-5fce-429c-b159-1e97ac3ac0d8"
      },
      "source": [
        "X, K = torch.tensor([0, 1, 2, 3, 4, 5, 6]), torch.tensor([1, 2])\n",
        "corr1d(X, K)"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([ 2.,  5.,  8., 11., 14., 17.])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_rwqDG6l1EE-",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "b725a7a8-6780-4121-f81c-2aa6d0ffdfca"
      },
      "source": [
        "def corr1d_multi_in(X, K):\n",
        "    return torch.stack([corr1d(x, k) for x, k in zip(X, K)]).sum(dim=0)\n",
        "\n",
        "X = torch.tensor([[0, 1, 2, 3, 4, 5, 6], \n",
        "         [1, 2, 3, 4, 5, 6, 7], \n",
        "         [2, 3, 4, 5, 6, 7, 8]])\n",
        "K = torch.tensor([[1, 2], [3, 4], [-1, -3]])\n",
        "corr1d_multi_in(X, K)"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([ 2.,  8., 14., 20., 26., 32.])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fGffQIv817sy",
        "colab_type": "text"
      },
      "source": [
        "### 10.8.2 时序最大池化层"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0RJPq9FB2HLp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class GlobalMaxPool1d(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(GlobalMaxPool1d, self).__init__()\n",
        "    \n",
        "    def forward(self, x):\n",
        "        return F.max_pool1d(x, kernel_size=x.shape[2])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vkcmXvc_2yvz",
        "colab_type": "text"
      },
      "source": [
        "### 10.8.3 读取和预处理 IMDb 数据集"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7IkG3id93BbG",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!mkdir ../data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2dZM4-cV4zxg",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install mxnet"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JNmn6yYL41Q7",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from mxnet.gluon import utils as gutils"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jWVYj3_u4EyR",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "DATA_ROOT = \"../data\""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZVZgiyCr481I",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def download_imdb(data_dir='../data'):\n",
        "    url = ('http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz')\n",
        "    sha1 = '01ada507287d82875905620988597833ad4e0903'\n",
        "    fname = gutils.download(url, data_dir, sha1_hash=sha1)\n",
        "    with tarfile.open(fname, 'r') as f:\n",
        "        f.extractall(data_dir)\n",
        "\n",
        "download_imdb()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ab_1_zEZ497i",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def read_imdb(folder='train'):  \n",
        "    data = []\n",
        "    for label in ['pos', 'neg']:\n",
        "        folder_name = os.path.join('../data/aclImdb/', folder, label)\n",
        "        for file in os.listdir(folder_name):\n",
        "            with open(os.path.join(folder_name, file), 'rb') as f:\n",
        "                review = f.read().decode('utf-8').replace('\\n', '').lower()\n",
        "                data.append([review, 1 if label == 'pos' else 0])\n",
        "    random.shuffle(data)\n",
        "    return data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OG8E5cJe5knp",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 84
        },
        "outputId": "34828051-d6b2-4d12-daca-3cc2dad61a82"
      },
      "source": [
        "batch_size = 64 \n",
        "train_data = d2l.read_imdb('train', data_root=os.path.join(DATA_ROOT, 'aclImdb'))\n",
        "test_data = d2l.read_imdb('test', data_root=os.path.join(DATA_ROOT, 'aclImdb'))\n",
        "vocab = d2l.get_vocab_imdb(train_data)\n",
        "train_set = Data.TensorDataset(*d2l.preprocess_imdb(train_data, vocab))\n",
        "test_set = Data.TensorDataset(*d2l.preprocess_imdb(test_data, vocab))\n",
        "train_iter = Data.DataLoader(train_set, batch_size, shuffle=True)\n",
        "test_iter = Data.DataLoader(test_set, batch_size)"
      ],
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "100%|██████████| 12500/12500 [00:00<00:00, 44420.38it/s]\n",
            "100%|██████████| 12500/12500 [00:00<00:00, 45405.60it/s]\n",
            "100%|██████████| 12500/12500 [00:00<00:00, 47225.96it/s]\n",
            "100%|██████████| 12500/12500 [00:00<00:00, 45261.04it/s]\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JwJncW5r6C0V",
        "colab_type": "text"
      },
      "source": [
        "### 10.8.4 textCNN 模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4WKetSlt6enu",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class TextCNN(nn.Module):\n",
        "    def __init__(self, vocab, embed_size, kernel_sizes, num_channels):\n",
        "        super(TextCNN, self).__init__()\n",
        "        self.embedding = nn.Embedding(len(vocab), embed_size)\n",
        "        self.constant_embedding = nn.Embedding(len(vocab), embed_size)\n",
        "        self.dropout = nn.Dropout(0.5)\n",
        "        self.decoder = nn.Linear(sum(num_channels), 2)\n",
        "        self.pool = GlobalMaxPool1d()\n",
        "        self.convs = nn.ModuleList()\n",
        "        for c, k in zip(num_channels, kernel_sizes):\n",
        "            self.convs.append(nn.Conv1d(in_channels = 2*embed_size, \n",
        "                          out_channels = c, \n",
        "                          kernel_size = k))\n",
        "\n",
        "    def forward(self, inputs):\n",
        "        embeddings = torch.cat((self.embedding(inputs), self.constant_embedding(inputs)), dim=2)\n",
        "        embeddings = embeddings.permute(0, 2, 1)\n",
        "        encoding = torch.cat([self.pool(F.relu(conv(embeddings))).squeeze(-1) for conv in self.convs], dim=1)\n",
        "        outputs = self.decoder(self.dropout(encoding))\n",
        "        return outputs "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pXeva6GFF3IV",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "embed_size, kernel_sizes, nums_channels = 100, [3, 4, 5], [100, 100, 100]\n",
        "net = TextCNN(vocab, embed_size, kernel_sizes, nums_channels)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n2i9UK_vIoqN",
        "colab_type": "text"
      },
      "source": [
        "#### 1 加载预训练的词向量"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DX8aDjQxIvWl",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 50
        },
        "outputId": "2a4f86be-0377-41f0-d505-284b0d5106a1"
      },
      "source": [
        "glove_vocab = Vocab.GloVe(name='6B', dim=100, cache=os.path.join(DATA_ROOT, 'glove'))"
      ],
      "execution_count": 28,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "../data/glove/glove.6B.zip: 862MB [06:30, 2.21MB/s]                          \n",
            "100%|█████████▉| 398098/400000 [00:20<00:00, 18741.06it/s]"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1vp58aplI8Et",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "net.embedding.weight.data.copy_(d2l.load_pretrained_embedding(vocab.itos, glove_vocab))\n",
        "net.constant_embedding.weight.data.copy_(d2l.load_pretrained_embedding(vocab.itos, glove_vocab))\n",
        "net.constant_embedding.weight.requires_grad = False "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Hlmor6rdJgUK",
        "colab_type": "text"
      },
      "source": [
        "#### 2 训练并评价模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6ks2XwV5JkUx",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 118
        },
        "outputId": "5b286516-80e4-4731-ad71-354ae98a6d4a"
      },
      "source": [
        "lr, num_epochs = 0.001, 5 \n",
        "optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr)\n",
        "loss = nn.CrossEntropyLoss()\n",
        "d2l.train(train_iter, test_iter, net, loss, optimizer, device, num_epochs)"
      ],
      "execution_count": 34,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "training on  cuda\n",
            "epoch 1, loss 0.4310, train acc 0.807, test acc 0.861, time 16.2 sec\n",
            "epoch 2, loss 0.1223, train acc 0.901, test acc 0.869, time 16.2 sec\n",
            "epoch 3, loss 0.0449, train acc 0.950, test acc 0.853, time 16.2 sec\n",
            "epoch 4, loss 0.0178, train acc 0.975, test acc 0.856, time 16.2 sec\n",
            "epoch 5, loss 0.0083, train acc 0.986, test acc 0.858, time 16.2 sec\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yFpjCf1SJ8jJ",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "4a61b711-4e53-47d3-f51d-ad2393baf935"
      },
      "source": [
        "d2l.predict_sentiment(net, vocab, ['this', 'movie', 'is', 'so', 'great'])"
      ],
      "execution_count": 35,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'positive'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 35
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_OSc_H8oKK7y",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "093b02d3-6972-46d1-a5bc-886b13757295"
      },
      "source": [
        "d2l.predict_sentiment(net, vocab, ['this', 'movie', 'is', 'so', 'bad'])"
      ],
      "execution_count": 36,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'negative'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 36
        }
      ]
    }
  ]
}